#include "irradiation_model.h"
#include <memory>

// std::chrono: wall-clock timing of the model's forward and backward passes
#include <chrono>

#include "cxxopts.hpp"

#include <ctime>

// Global switch for the "finish ..." / "success ..." progress messages that
// main() prints between setup and training stages; off by default.
bool debug_on{false};


/// Read an initial state stored in the CSV layout written by print_img_in_csv.
///
/// Expected layout: a header line "Nx,Ny,n_grains" (an optional fourth field
/// ",n_step" is also accepted), followed by n_grains * Nx rows of Ny
/// comma-separated values. Row x of grain pg fills
/// mtx[pg*Nx*Ny + x*Ny + 0 .. pg*Nx*Ny + x*Ny + Ny-1].
///
/// NOTE(review): values are parsed with "%lf", so valueType must be double.
///
/// @param filename  path of the CSV file to read
/// @param Nx        [out] first grid dimension from the header
/// @param Ny        [out] second grid dimension (values per row)
/// @param n_grains  [out] number of grain planes
/// @param n_step    [out] written only when the header carries a fourth
///                  field; otherwise left as the caller set it
/// @return a new[]-allocated buffer of Nx*Ny*n_grains values that the caller
///         must delete[], or nullptr (with a message on stderr) when the file
///         cannot be opened, the header is malformed or has a zero dimension,
///         or a value fails to parse. On failure the out-params other than
///         those already parsed are left untouched.
valueType *read_init_state(const char *filename, uint &Nx, uint &Ny, uint &n_grains, uint &n_step)
{
    std::unique_ptr<FILE, decltype(&fclose)> inp(fopen(filename, "r"), &fclose);
    if (!inp)
    {
        fprintf(stderr, "read_init_state: cannot open %s\n", filename);
        return nullptr;
    }

    // The old format "%u,%u,%u" had three conversions for four targets, so
    // n_step was never read. Ask for four fields and accept three: with a
    // 3-field header the ',' literal fails to match '\n', fscanf returns 3 and
    // the stream stays at the end of the header line.
    uint hdr_nx = 0, hdr_ny = 0, hdr_ng = 0, hdr_ns = 0;
    const int n_fields = fscanf(inp.get(), "%u,%u,%u,%u", &hdr_nx, &hdr_ny, &hdr_ng, &hdr_ns);
    if (n_fields < 3 || hdr_nx == 0 || hdr_ny == 0 || hdr_ng == 0)
    {
        // A zero dimension would make the row loop below write into an
        // empty buffer.
        fprintf(stderr, "read_init_state: malformed header in %s\n", filename);
        return nullptr;
    }
    Nx = hdr_nx;
    Ny = hdr_ny;
    n_grains = hdr_ng;
    if (n_fields == 4)
        n_step = hdr_ns;

    // size_t offsets so pg*Nx*Ny cannot wrap around in 32-bit arithmetic.
    const std::size_t plane = static_cast<std::size_t>(Nx) * Ny;
    std::unique_ptr<valueType[]> mtx(new valueType[plane * n_grains]);
    for (std::size_t pg = 0; pg < n_grains; ++pg)
    {
        for (std::size_t x = 0; x < Nx; ++x)
        {
            valueType *row = mtx.get() + pg * plane + x * Ny;
            // First value has no leading comma; the rest are ",value".
            bool ok = fscanf(inp.get(), "%lf", row) == 1;
            for (std::size_t y = 1; ok && y < Ny; ++y)
                ok = fscanf(inp.get(), ",%lf", row + y) == 1;
            if (!ok)
            {
                fprintf(stderr, "read_init_state: bad value in %s (grain %zu, row %zu)\n",
                        filename, pg, x);
                return nullptr;
            }
        }
    }
    return mtx.release();
}

/// Write n_grains planes of an Nx-by-Ny grid to `filename` as CSV.
///
/// Layout (the one read_init_state parses): a header line "Nx,Ny,n_grains",
/// then one line per (grain, row) pair holding that row's Ny values, where
/// element (pg, i, j) is img[pg*Nx*Ny + i*Ny + j]. Values are written with
/// "%lf", i.e. 6 digits after the decimal point.
///
/// @param img       buffer of at least Nx*Ny*n_grains values (only read)
/// @param filename  output path; an existing file is truncated ("w" mode)
/// @param Nx        number of rows per grain plane
/// @param Ny        number of values per row
/// @param n_grains  number of grain planes in img
///
/// If the file cannot be opened, a message goes to stderr and nothing is
/// written. With Ny == 0 only the header is written.
void print_img_in_csv(valueType *img, const char *filename, uint Nx, uint Ny,
                      uint n_grains)
{
    std::unique_ptr<FILE, decltype(&fclose)> oup(fopen(filename, "w"), &fclose);
    if (!oup)
    {
        fprintf(stderr, "print_img_in_csv: cannot open %s for writing\n", filename);
        return;
    }

    fprintf(oup.get(), "%u,%u,%u\n", Nx, Ny, n_grains);
    // Each row prints at least its first value; with Ny == 0 there is no
    // such value and the old loop read img[...] past the (empty) data.
    if (Ny == 0)
        return;

    // size_t offsets so pg*Nx*Ny cannot wrap around in 32-bit arithmetic.
    const std::size_t plane = static_cast<std::size_t>(Nx) * Ny;
    for (std::size_t pg = 0; pg < n_grains; ++pg)
    {
        for (std::size_t i = 0; i < Nx; ++i)
        {
            const valueType *row = img + pg * plane + i * Ny;
            fprintf(oup.get(), "%lf", row[0]);
            for (std::size_t j = 1; j < Ny; ++j)
                fprintf(oup.get(), ",%lf", row[j]);
            fputc('\n', oup.get());
        }
    }
}

/// Command-line settings for the simulation driver, filled in by parse_args.
///
/// Every member carries the same default that parse_args applies when the
/// corresponding flag is absent, so a default-constructed Args never holds
/// indeterminate values.
class Args
{
public:
    uint nsteps = 100;                        // --nsteps / -s
    std::string input = "grain.in";           // --input / -i
    std::string output = "grain.out";         // --output / -o
    std::string bucket_output = "bucket.out"; // --bucket_output
    uint lshL = 1;                            // --lshL
    uint lshK = 1;                            // --lshK
    double lshr = 1e-4;                       // --lshr
};

/// Parse the command line into a newly allocated Args.
///
/// Recognised options: -s/--nsteps, -o/--output, -i/--input, --lshK, --lshL,
/// --lshr, --bucket_output, -h/--help. Unrecognised options are allowed and
/// ignored. Each option that is present is echoed to stdout; absent options
/// take the defaults shown in their help text.
///
/// @param argc, argv  the arguments passed to main
/// @return a new-allocated Args; ownership passes to the caller (delete it).
///
/// Exits the process with status 0 after printing help for --help, and with
/// status 1 on a parse error or a negative value for an unsigned option.
Args *parse_args(int argc, const char *argv[])
{
    try
    {
        // Owned here until the end so nothing leaks if parsing throws.
        std::unique_ptr<Args> args(new Args);

        cxxopts::Options options(argv[0], " - test forward simulation of grain growth.");
        options
            .positional_help("[optional args]")
            .show_positional_help();

        options
            .set_width(70)
            .set_tab_expansion()
            .allow_unrecognised_options()
            .add_options()
            ("s,nsteps", "Number of steps of simulation (default=100)", cxxopts::value<int>(), "N")
            ("o,output", "Output file (default=grain.out)", cxxopts::value<std::string>(), "FILE")
            ("i,input", "Input file (default=grain.in)", cxxopts::value<std::string>(), "FILE")
            ("lshK", "K for LSH (default=1)", cxxopts::value<int>(), "INT")
            ("lshL", "L for LSH (default=1)", cxxopts::value<int>(), "INT")
            // Parsed as double (not float) so the value is not rounded to
            // single precision before being stored in Args::lshr.
            ("lshr", "r for LSH (default=1e-4)", cxxopts::value<double>(), "FLOAT")
            ("bucket_output", "Output file of the bucket information (default=bucket.out)", cxxopts::value<std::string>(), "FILE")
            ("h,help", "Print help")
#ifdef CXXOPTS_USE_UNICODE
            ("unicode", u8"A help option with non-ascii: à. Here the size of the"
                        " string should be correct")
#endif
            ;

        auto result = options.parse(argc, argv);

        if (result.count("help"))
        {
            std::cout << options.help({"", "Group"}) << std::endl;
            exit(0);
        }

        std::cout << "[Parse Args]" << std::endl;

        // Echo and return a non-negative integer option, or `fallback` if it
        // is absent. Negative values are rejected instead of wrapping around
        // in the cast to uint.
        auto uint_option = [&result](const std::string &name, uint fallback) -> uint {
            if (!result.count(name))
                return fallback;
            const int value = result[name].as<int>();
            std::cout << "  " << name << " = " << value << std::endl;
            if (value < 0)
            {
                std::cout << "error parsing options: --" << name
                          << " must be non-negative" << std::endl;
                exit(1);
            }
            return static_cast<uint>(value);
        };

        // Echo and return a string option, or `fallback` if it is absent.
        auto string_option = [&result](const std::string &name,
                                       const std::string &fallback) -> std::string {
            if (!result.count(name))
                return fallback;
            const std::string value = result[name].as<std::string>();
            std::cout << "  " << name << " = " << value << std::endl;
            return value;
        };

        args->nsteps = uint_option("nsteps", 100);
        args->output = string_option("output", "grain.out");
        args->input = string_option("input", "grain.in");
        args->bucket_output = string_option("bucket_output", "bucket.out");
        args->lshK = uint_option("lshK", 1);
        args->lshL = uint_option("lshL", 1);

        if (result.count("lshr"))
        {
            args->lshr = result["lshr"].as<double>();
            std::cout << "  lshr = " << args->lshr << std::endl;
        }
        else
        {
            args->lshr = 1e-4;
        }

        auto arguments = result.arguments();
        std::cout << "  Saw " << arguments.size() << " arguments" << std::endl;

        std::cout << "[End of Parse Args]" << std::endl;

        return args.release();
    }
    catch (const cxxopts::OptionException &e)
    {
        std::cout << "error parsing options: " << e.what() << std::endl;
        exit(1);
    }
}

int main(int argc, const char *argv[])
{

    Args *args = parse_args(argc, argv);

    // def parameters
    uint Nx = 64; //1024;   these will be changed later.
    uint Ny = 64; //1024;
    uint n_grains = 2;
    uint n_step = 500;

    uint lshK = args->lshK;
    uint lshL = args->lshL;
    valueType lsh_r = args->lshr;
    uint nsteps = args->nsteps;

    valueType h = 0.5;

    valueType A = 1.0;
    valueType B = 1.0;
    valueType L = 5.0;
    valueType kappa = 0.1;

    valueType dtime = 0.05;
    valueType ttime = 0.0;

    valueType init_L = 2.0; // try to learn to 5.0
    valueType init_A = 2.0; // try to learn to 1.0
    valueType init_B = 3.0; // try to learn to 1.0
    valueType init_kappa = 0.9; // try to learn to 0.1

    double lr = 1e-1;
    uint start_skip = 1;
    uint skip_step = 30;
    uint epoch = 500;
    double lambda = 10;

    char* data_path = "../data/grain_growth_all_data_1";
    GrainGrowthDataset dataset(data_path, start_skip, skip_step);

    Nx = dataset.Nx;
    Ny = dataset.Ny;
    n_grains = dataset.n_grains;
    n_step = dataset.n_step;

    if (debug_on) {
        std::cout << "finish data loading" << std::endl;
    }


    double min_loss = 1000.0;

    int seed = 4321;
    torch::manual_seed(seed);
    torch::cuda::manual_seed(seed);
    torch::autograd::AnomalyMode::set_enabled(true); 

    ggTimeStep gg_model(dtime, h, h, 1e-3, 1, init_L, init_A, init_B, init_kappa, Nx, Ny, n_grains, h);

    torch::Device device(torch::kCPU);

    gg_model->to(device);

    if (debug_on) {
        std::cout << "finish ts model init" << std::endl;
    }

    torch::nn::MSELoss mse(torch::nn::MSELossOptions(torch::kSum));

    torch::optim::Adam optimizer(gg_model->parameters(), torch::optim::AdamOptions(lr));

    if (debug_on) {
        std::cout << "finish optim mseloss init" << std::endl;
    }



    for (int i = 0; i < epoch; ++ i) {
        double loss = 0.0;
        int total_size = 0;
        printf("epoch:\t%d\n", i);
        // for (int index = start_skip; index < start_skip + 10; ++index) {        // for testing
        for (int index = start_skip; index < dataset.get_len() - 10; ++index) {
            ReturnItem rt = dataset.get_item(index);

            valueType* eta1_start = rt.data.eta1;
            valueType* eta2_start = rt.data.eta2;
            torch::Tensor eta_1_t = torch::from_blob(eta1_start, {Nx, Ny}, torch::dtype(torch::kFloat64)).clone();
            torch::Tensor eta_2_t = torch::from_blob(eta2_start, {Nx, Ny}, torch::dtype(torch::kFloat64)).clone();

            if (debug_on) {
                std::cout << "success get data" << std::endl;
                // std::cout << "frame1 size: " << frame1.sizes() << std::endl;
            }

            torch::Tensor concate_vals;
            auto start = std::chrono::high_resolution_clock::now();
            for (int j = 0; j < dataset.skip_step; ++ j) {
                concate_vals = gg_model->forward(eta_1_t, eta_2_t);
                eta_1_t = concate_vals.index({0});
                eta_2_t = concate_vals.index({0});
                // eta_1_t.unsqueeze_(0);
                // eta_2_t.unsqueeze_(0);
                // std::cout << "concate_vals size: " << concate_vals.sizes() << std::endl;
                // std::cout << "eta1 size: " << eta_1_t.sizes() << ", eta2 size: " << eta_2_t.sizes() << std::endl;
            }
            auto stop = std::chrono::high_resolution_clock::now();
            auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start);
            std::cout << "time of gg model forward: " << duration.count() << "ms in " << dataset.skip_step << "steps" << std::endl;

            if (debug_on) {
                std::cout << "success forward in ts model" << std::endl;
            }

            valueType* eta1_ref_start = rt.ref.eta1_ref;
            valueType* eta2_ref_start = rt.ref.eta2_ref;
            torch::Tensor eta_1_ref_t = torch::from_blob(eta1_ref_start, {Nx, Ny}, torch::dtype(torch::kFloat64)).clone();
            torch::Tensor eta_2_ref_t = torch::from_blob(eta2_ref_start, {Nx, Ny}, torch::dtype(torch::kFloat64)).clone();

            if (debug_on) {
                std::cout << "success get ref data" << std::endl;
                // std::cout << "frame1 size: " << frame1.sizes() << std::endl;
            }

            torch::Tensor eta_1_batch_loss = lambda * mse->forward(eta_1_t, eta_1_ref_t);
            torch::Tensor eta_2_batch_loss = lambda * mse->forward(eta_2_t, eta_2_ref_t);

            torch::Tensor batch_loss = eta_1_batch_loss + eta_2_batch_loss;

            if (debug_on) {
                std::cout << "success get loss" << std::endl;
            }

            optimizer.zero_grad();

            auto start_back = std::chrono::high_resolution_clock::now();
            batch_loss.backward();
            auto stop_back = std::chrono::high_resolution_clock::now();
            auto duration_back = std::chrono::duration_cast<std::chrono::milliseconds>(stop_back - start_back);
            std::cout << "time of ts model backward: " << duration_back.count() << "ms in " << dataset.skip_step << "steps" << std::endl;

            optimizer.step();

            int this_size = 1;
            loss += (batch_loss.item<valueType>());
            if (true) {
                std::cout << "batch loss: " << (batch_loss.item<valueType>()) << std::endl;
                std::cout << "eta1 loss: " << eta_1_batch_loss.item<valueType>() << ", eta2 loss: " \ 
                            << eta_2_batch_loss.item<valueType>() << ", eta loss: " << std::endl;
                {
                    for (const auto& pair : gg_model->named_parameters()) {
                        std::cout << pair.key() << "'s grad: " << pair.value().grad() << std::endl;
                        std::cout << pair.key() << "'s value: " << pair.value() << std::endl;
                    }
                }
            }
            total_size += this_size;
        }

        loss /= total_size;
        printf("loss:\t%.8f\n", loss);

        if (loss < min_loss) {
            min_loss = loss;
            // save ts_model
            // save video2pf 
        }
        else {
            printf("Above min_loss\n");
        }
    }

    return 0;
}
